
feat(pt_expt): add dp freeze support and dp test tests for .pte models#5302

Merged
wanghan-iapcm merged 6 commits into deepmodeling:master from wanghan-iapcm:feat-pt-expt-dpfrztest
Mar 18, 2026
Conversation

@wanghan-iapcm (Collaborator) commented Mar 10, 2026

Summary

  • Add dp freeze support for the pt_expt backend, enabling checkpoint .pt → exported .pte conversion
  • Add end-to-end tests for both dp freeze and dp test with .pte models

Background

The pt_expt backend can export models to .pte via deserialize_to_file(), and dp test can already load .pte models through the registered DeepEval. However, dp freeze was not
wired up — calling dp freeze -b pt-expt hit RuntimeError: Unsupported command 'freeze'.

Changes

deepmd/pt_expt/entrypoints/main.py

  • Add freeze() function: loads .pt checkpoint → reconstructs model via get_model + ModelWrapper → serializes → exports to .pte via deserialize_to_file
  • Wire freeze command in main() dispatcher with checkpoint directory resolution and .pte default suffix
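The checkpoint-resolution step described above can be sketched in pure pathlib, independent of the deepmd code (resolve_checkpoint is a hypothetical helper name; model.ckpt.pt is the symlink name pt_expt training produces):

```python
from pathlib import Path

def resolve_checkpoint(path: str) -> str:
    """Resolve the dp freeze checkpoint argument: a directory is mapped
    to its model.ckpt.pt symlink (the latest checkpoint written by
    pt_expt training); a file path is used as-is."""
    p = Path(path)
    if p.is_dir():
        p = p / "model.ckpt.pt"
    return str(p)
```

For example, resolve_checkpoint("ckpt_dir") yields "ckpt_dir/model.ckpt.pt" when ckpt_dir is an existing directory, while an explicit file path passes through unchanged.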

source/tests/pt_expt/test_dp_freeze.py (new)

  • test_freeze_pte — verify .pte file is created from checkpoint
  • test_freeze_main_dispatcher — test main() CLI dispatcher with freeze command
  • test_freeze_default_suffix — verify non-.pte output suffix is corrected to .pte

source/tests/pt_expt/test_dp_test.py (new)

  • test_dp_test_system — test dp test with -s system path, verify .e.out, .f.out, .v.out outputs
  • test_dp_test_input_json — test dp test with --valid-data JSON input

Test plan

  • python -m pytest source/tests/pt_expt/test_dp_freeze.py -v (3 passed)
  • python -m pytest source/tests/pt_expt/test_dp_test.py -v (2 passed)

Summary by CodeRabbit

  • New Features
    • Added a "freeze" CLI command to convert PyTorch checkpoints into portable .pte model files, with output filename normalization and sensible default naming; multi-task checkpoints are reported as unsupported.
  • Tests
    • Added unit tests for the freeze function and CLI dispatch behavior.
    • Added integration tests validating end-to-end dp_test workflows using frozen models.

Add freeze() function to pt_expt backend that loads a .pt checkpoint,
reconstructs the model, serializes it, and exports to .pte via
deserialize_to_file. Wire the freeze command in the main() CLI dispatcher.

Add separate test files for dp freeze (test_dp_freeze.py) and dp test
(test_dp_test.py) verifying the full freeze-then-test pipeline works
end-to-end with .pte models.

@chatgpt-codex-connector (Bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 35745c127e


@coderabbitai (Bot) commented Mar 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a new "freeze" CLI command and a freeze() function to the pt_expt backend to load PyTorch checkpoints, instantiate/wrap models, load state, and serialize them to .pte. Adds unit tests validating freeze behavior and dp_test integration with frozen models.

Changes

  • pt_expt CLI & freeze implementation (deepmd/pt_expt/entrypoints/main.py): adds freeze(model, output="frozen_model.pte", head=None), resolves checkpoint paths (dir → model.ckpt.pt), reads the checkpoint's _extra_state and state_dict, guards the multi-task head with NotImplementedError, instantiates the model via get_model()/ModelWrapper, loads state, and exports .pte. Integrates freeze into main() and normalizes output suffixes.
  • freeze unit tests (source/tests/pt_expt/test_dp_freeze.py): new tests (TestDPFreezePtExpt) that create a fake checkpoint and assert .pte output for direct freeze() calls, for CLI dispatch (main() with freeze), and for default suffix normalization when given .pth.
  • dp_test integration tests (source/tests/pt_expt/test_dp_test.py): new tests (TestDPTestPtExpt) that build a model (model_se_e2_a), save a checkpoint, freeze it to .pte, run dp_test with filesystem and JSON inputs, and assert the expected detail outputs (.e.out, .f.out, .v.out).

Sequence Diagram

sequenceDiagram
    participant CLI as main()
    participant Freeze as freeze()
    participant ModelBuilder as get_model()/ModelWrapper
    participant Storage as Filesystem

    CLI->>Freeze: freeze(checkpoint_path, output_file, head)
    Freeze->>Freeze: resolve checkpoint path (dir → model.ckpt.pt)
    Freeze->>Storage: read checkpoint (state_dict, _extra_state)
    Freeze->>ModelBuilder: get_model() → instantiate model
    Freeze->>ModelBuilder: wrap model, load state_dict
    ModelBuilder->>Freeze: wrapped model ready
    Freeze->>Storage: deserialize_to_file(wrapped_model, output.pte)
    Storage->>CLI: saved path / success log

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested labels

enhancement

Suggested reviewers

  • iProzd
  • njzjz
🚥 Pre-merge checks: ✅ 3 passed

  • Description check ✅ Passed: check skipped; CodeRabbit's high-level summary is enabled.
  • Title check ✅ Passed: the title accurately summarizes the main changes: adding dp freeze support and dp test tests for .pte models in the pt_expt backend.
  • Docstring coverage ✅ Passed: docstring coverage is 100.00%, which meets the required threshold of 80.00%.


@coderabbitai (Bot) left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
source/tests/pt_expt/test_dp_test.py (1)

27-45: Consider extracting shared test fixtures.

The model_se_e2_a configuration and checkpoint creation pattern are duplicated in test_dp_freeze.py. Consider extracting these to a shared conftest or helper module to reduce duplication.
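A minimal sketch of that extraction (the helper name make_checkpoint and the config values are hypothetical; json stands in for torch.save so the example is self-contained):

```python
# conftest.py sketch: shared helpers for test_dp_freeze.py and test_dp_test.py
import json
from pathlib import Path

# Hypothetical stand-in for the duplicated model_se_e2_a config dict.
MODEL_SE_E2_A = {
    "type_map": ["O", "H"],
    "descriptor": {"type": "se_e2_a", "rcut": 6.0, "sel": [46, 92]},
    "fitting_net": {"neuron": [24, 24, 24]},
}

def make_checkpoint(tmp_path: Path, model_params: dict) -> Path:
    """Write a minimal fake checkpoint and return its path (the real
    helper would use torch.save with a state_dict and _extra_state)."""
    ckpt = tmp_path / "model.ckpt.pt"
    ckpt.write_text(json.dumps({"_extra_state": {"model_params": model_params}}))
    return ckpt
```

Both test classes would then consume these via fixtures instead of redefining the dict and checkpoint code.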

deepmd/pt_expt/entrypoints/main.py (1)

256-257: Minor: .pt2 suffix check might be undocumented.

The code accepts both .pte and .pt2 suffixes, but the docstring and default only mention .pte. Consider documenting .pt2 if it's intentionally supported, or remove it if not needed.

Inline comment

deepmd/pt_expt/entrypoints/main.py, lines 250-253: the checkpoint name read via (checkpoint_path / "checkpoint").read_text() may carry a trailing newline or other whitespace, which breaks the FLAGS.model path construction. Strip it before joining, e.g. FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file.strip())).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: e8b4749c-4b84-4f85-a10c-9f671ace384a

📥 Commits

Reviewing files that changed from the base of the PR and between 24e54bf and 35745c1.

📒 Files selected for processing (3)
  • deepmd/pt_expt/entrypoints/main.py
  • source/tests/pt_expt/test_dp_freeze.py
  • source/tests/pt_expt/test_dp_test.py

@codecov (Bot) commented Mar 10, 2026

Codecov Report

❌ Patch coverage is 77.77778% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.30%. Comparing base (24e54bf) to head (33fe75e).
⚠️ Report is 11 commits behind head on master.

Files with missing lines:
  • deepmd/pt_expt/entrypoints/main.py: patch 77.77%, 8 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5302      +/-   ##
==========================================
- Coverage   82.32%   82.30%   -0.02%     
==========================================
  Files         768      775       +7     
  Lines       77098    77664     +566     
  Branches     3659     3675      +16     
==========================================
+ Hits        63469    63924     +455     
- Misses      12458    12567     +109     
- Partials     1171     1173       +2     

☔ View full report in Codecov by Sentry.

pt_expt training saves checkpoints as model.ckpt-{step}.pt with a
model.ckpt.pt symlink, not a "checkpoint" text file. The previous
code was copied from the pt backend which uses a different format.
@wanghan-iapcm wanghan-iapcm requested a review from iProzd March 15, 2026 15:03
@coderabbitai (Bot) left a comment

Actionable comments posted: 2

Inline comments

deepmd/pt_expt/entrypoints/main.py:

  • Lines 194-198: the checkpoint loading code assumes a nested schema and directly accesses state_dict["_extra_state"]["model_params"], which can raise an uninformative KeyError. After torch.load, validate the schema explicitly: ensure state_dict is a dict (after handling the optional "model" wrapper), that "_extra_state" is present and is a dict, and that it contains "model_params"; if any check fails, raise a clear ValueError naming the expected keys and the actual top-level keys so callers get an actionable error.
  • Lines 261-263: FLAGS.model = FLAGS.checkpoint_folder is assigned without validating that the path exists, deferring errors to torch.load. Check the filesystem first: accept an existing directory or file, and otherwise fail fast with a clear CLI error naming the invalid checkpoint path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: f707d44b-ebab-4908-9e8f-4aee84d87b86

📥 Commits

Reviewing files that changed from the base of the PR and between 35745c1 and 11f96ba.

📒 Files selected for processing (1)
  • deepmd/pt_expt/entrypoints/main.py

@wanghan-iapcm wanghan-iapcm enabled auto-merge March 16, 2026 11:32
@coderabbitai (Bot) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
deepmd/pt_expt/entrypoints/main.py (1)

194-199: ⚠️ Potential issue | 🟡 Minor

Validate checkpoint root type before nested access.

state_dict.get("_extra_state") assumes a dict-like object. If a malformed/incompatible checkpoint is loaded, this path raises an opaque AttributeError instead of a clear CLI error.

Proposed fix
     state_dict = torch.load(model, map_location=DEVICE, weights_only=True)
     if "model" in state_dict:
         state_dict = state_dict["model"]

+    if not isinstance(state_dict, dict):
+        raise ValueError(
+            f"Unsupported checkpoint format at '{model}': "
+            f"expected dict-like state_dict, got {type(state_dict).__name__}."
+        )
+
     extra_state = state_dict.get("_extra_state")
     if not isinstance(extra_state, dict) or "model_params" not in extra_state:
         raise ValueError(
             f"Unsupported checkpoint format at '{model}': missing "
             "'_extra_state.model_params' in model state dict."
         )
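Pulled out as a standalone function, the proposed validation can be exercised without a real checkpoint (a sketch; in the actual code the input is the object returned by torch.load):

```python
def validate_state_dict(state_dict, model="model.ckpt.pt"):
    """Validate the expected pt_expt checkpoint schema, raising a clear
    ValueError instead of an opaque AttributeError or KeyError."""
    if isinstance(state_dict, dict) and "model" in state_dict:
        state_dict = state_dict["model"]
    if not isinstance(state_dict, dict):
        raise ValueError(
            f"Unsupported checkpoint format at '{model}': "
            f"expected dict-like state_dict, got {type(state_dict).__name__}."
        )
    extra_state = state_dict.get("_extra_state")
    if not isinstance(extra_state, dict) or "model_params" not in extra_state:
        raise ValueError(
            f"Unsupported checkpoint format at '{model}': missing "
            "'_extra_state.model_params' in model state dict."
        )
    return extra_state["model_params"]
```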
Inline comment

deepmd/pt_expt/entrypoints/main.py, lines 206-209: the current guard only raises NotImplementedError when both a head is provided and "model_dict" is in model_params, but any multi-task checkpoint should be blocked regardless of head. Raise NotImplementedError whenever "model_dict" is present in model_params so multi-task checkpoints are unconditionally rejected by the pt_expt backend.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 7dd197e6-754f-4544-958a-d2e0930300c4

📥 Commits

Reviewing files that changed from the base of the PR and between e4298d3 and 33fe75e.

📒 Files selected for processing (1)
  • deepmd/pt_expt/entrypoints/main.py

A multi-task checkpoint without --head would silently fall through
and produce wrong results. Remove the head check so any multi-task
checkpoint raises NotImplementedError regardless of --head flag.
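The tightened guard reduces to something like this (a sketch; model_params is the dict recovered from the checkpoint's _extra_state):

```python
def check_single_task(model_params: dict) -> None:
    """Reject multi-task checkpoints unconditionally: the pt_expt freeze
    path cannot select a head from a multi-task model yet."""
    if "model_dict" in model_params:
        raise NotImplementedError(
            "Multi-task checkpoints are not supported by the pt_expt backend."
        )
```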
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Mar 18, 2026
@coderabbitai (Bot) left a comment

Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a2074e19-8dbe-4a03-979b-0be4d7ba4ba2

📥 Commits

Reviewing files that changed from the base of the PR and between 33fe75e and 8479016.

📒 Files selected for processing (1)
  • deepmd/pt_expt/entrypoints/main.py

Comment on lines +275 to +276
if not FLAGS.output.endswith((".pte", ".pt2")):
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
⚠️ Potential issue | 🟠 Major



Tighten output normalization to accept only .pte extension.

Lines 275-276 preserve both .pte and .pt2 suffixes, but the backend and all loaders (serialize_from_file, DeepEval, torch.export.load) only support .pte. If users provide .pt2 output, the resulting files cannot be loaded downstream, creating broken artifacts.

Suggested fix
-        if not FLAGS.output.endswith((".pte", ".pt2")):
+        if not FLAGS.output.endswith(".pte"):
             FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
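The suggested normalization behaves as follows (a quick stdlib check of str.endswith and Path.with_suffix, independent of the deepmd code; normalize_output is a hypothetical wrapper):

```python
from pathlib import Path

def normalize_output(output: str) -> str:
    """Force a .pte suffix on the freeze output path, per the suggested fix."""
    if not output.endswith(".pte"):
        output = str(Path(output).with_suffix(".pte"))
    return output

print(normalize_output("frozen_model"))      # frozen_model.pte
print(normalize_output("frozen_model.pt2"))  # frozen_model.pte
print(normalize_output("frozen_model.pte"))  # frozen_model.pte
```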

Merged via the queue into deepmodeling:master with commit 3ab3779 Mar 18, 2026
3 of 4 checks passed
@wanghan-iapcm wanghan-iapcm deleted the feat-pt-expt-dpfrztest branch March 18, 2026 05:26


3 participants